
Conversation


@aolemila aolemila commented Dec 2, 2025

Resolves #825.

  1. Add code in scripts/grpo_demo_llama3_qwen2.py to run LoRA.
  2. Add sglang_jax_lora_test.py to ensure update_params works, and add it to tpu-tests.yml. verify_update_params is executed when VERIFY_UPDATE_PARAMS_KEY is configured.
  3. Add more fields for SGLangJax in RolloutConfig.
  4. Pass the following tests (environment: TPU-v6e-4).
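The env-var gate described in point 2 can be sketched as follows. This is an illustration only: the helper name and the variable's value are assumptions, not the repo's actual code (the real VERIFY_UPDATE_PARAMS_KEY constant lives in the repository).

```python
import os

# Assumed value for illustration; the real constant is defined in the repo.
VERIFY_UPDATE_PARAMS_KEY = "VERIFY_UPDATE_PARAMS"

def maybe_verify_update_params(verify_fn):
    """Run the verification callback only when the env var is configured."""
    mapping = os.getenv(VERIFY_UPDATE_PARAMS_KEY)
    if mapping is None:
        return False  # verification skipped when the key is unset
    verify_fn(mapping)
    return True
```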

Test1: Run verification of update_params

JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache python3 tests/generate/sglang_jax_lora_test.py

Test2: Run scripts/grpo_demo_llama3_qwen2.py without LoRA

JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache python3 scripts/grpo_demo_llama3_qwen2.py --num-batches 2 --num-test-batches 1 --root-dir=/home/gcpuser/aolemila --rollout-engine sglang_jax

Test3: Run scripts/grpo_demo_llama3_qwen2.py with LoRA

JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache python3 scripts/grpo_demo_llama3_qwen2.py --num-batches 2 --num-test-batches 1 --root-dir=/home/gcpuser/aolemila --rollout-engine sglang_jax --enable-lora --lora-target-modules all

Reference

Colab Notebook

Checklist

  • I have added all the necessary unit tests for my change.
  • I have verified that my change does not break existing code and all unit tests pass.
  • I have added all appropriate doc-strings/documentation.
  • My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • I have signed the Contributor License Agreement.
  • I have followed the Contribution Guidelines.

aolemila force-pushed the feat/add-lora-for-sglangjax branch from 32a06fd to 3af36c3 on December 3, 2025 09:54
aolemila changed the title from [WIP] Feat/add lora for sglangjax to [Feature] Feat/add lora for sglangjax on Dec 3, 2025
aolemila force-pushed the feat/add-lora-for-sglangjax branch from 3ea01a4 to a56ef3a on December 4, 2025 06:46
aolemila force-pushed the feat/add-lora-for-sglangjax branch from 3dfcb05 to bb7bd21 on December 5, 2025 02:46
@wang2yn84 (Collaborator) commented:

Hi @aolemila, thank you for the PR! Can you rebase onto head and resolve the conflicts? We've removed the sglang script, so this should merge into our main script. Please also squash the commits.

@wang2yn84 (Collaborator) left a review:

Thank you for your PR! I left some comments.

@aolemila (Collaborator Author) commented:

Hi @wang2yn84, thanks for your reply. I will rebase onto main and modify the code according to your advice.

@aolemila (Collaborator Author) commented:

I am rerunning the scripts and fixing the new problems I encounter.

aolemila force-pushed the feat/add-lora-for-sglangjax branch from bb7bd21 to e32615c on December 25, 2025 12:44
aolemila changed the title from [Feature] Feat/add lora for sglangjax to [WIP] Feat/add lora for sglangjax on Dec 25, 2025
aolemila force-pushed the feat/add-lora-for-sglangjax branch from e32615c to 3f755c5 on December 26, 2025 03:25
aolemila changed the title from [WIP] Feat/add lora for sglangjax to Feat/add lora for sglangjax on Dec 26, 2025
aolemila self-assigned this on Dec 26, 2025
aolemila force-pushed the feat/add-lora-for-sglangjax branch from 3f755c5 to 874bfbc on December 26, 2025 04:26
@aolemila (Collaborator Author) commented:

Hi @wang2yn84, I have updated the code according to your suggestions. In addition, all three test cases below pass; see the PR description for more details.

  • Test1: Run verification of update_params
  • Test2: Run scripts/grpo_demo_llama3_qwen2.py without LoRA
  • Test3: Run scripts/grpo_demo_llama3_qwen2.py with LoRA

new_model_state_leaves, _ = jax.tree_util.tree_flatten(new_state)
self._model_runner.model_state_leaves = new_model_state_leaves

flatten_src_to_tgt_module_name = os.getenv(VERIFY_UPDATE_PARAMS_KEY, None)
Collaborator:

This part of the validation should belong in the test rather than the production code, right?

Collaborator Author:

ok

Collaborator Author:

Fixed here: commit.

@@ -0,0 +1,371 @@
"""
Collaborator:

Sorry, I didn't look into the details in the last round of review. This integration test seems quite heavy, running the whole GRPO workflow with a 3B model. Such a test belongs in nightly regression. For CI, can we have some lightweight validation, such as just testing update_params?

Collaborator Author:

OK. I added python scripts/grpo_demo_llama3_qwen2.py --num-batches 2 --num-test-batches 1 --root-dir=/home/gcpuser/aolemila --rollout-engine sglang_jax --enable-lora --lora-target-modules all to tpu-nightly-regression.yml to run the LoRA case. As for tests/generate/sglang_jax_lora_test.py in tpu-tests.yml, I will simplify it to make it more lightweight.

Collaborator Author:

Fixed here: commit.

return text.split("####")[1].strip()


def download_kaggle_dataset(target_dir="./data/gsm8k"):
Collaborator:

Can we leverage the existing API? We already have dataset-loading and get-LoRA-model APIs, so there is no need to recreate these functions. If the existing API is not sufficient (say, a dataset is not supported), can you help improve it, perhaps in a separate PR?

Collaborator Author:

OK. This code was based on an old version of scripts/grpo_demo_llama3_qwen2.py, so it may be outdated. I will follow the latest scripts/grpo_demo_llama3_qwen2.py and use the recommended APIs.

Collaborator Author:

This is not used in the simplified version. Fixed here: commit.

# List of batch sizes buckets for jax jit
rollout_sglang_jax_precompile_bs_paddings: Optional[List] = None
# Whether to use lora
rollout_sglang_jax_enable_static_lora: bool = False
Collaborator:

Is this supposed to be True? IIUC, the way Tunix uses LoRA is static, since we don't need to select from multiple LoRAs or switch them on the fly.

@aolemila (Collaborator Author) commented Jan 15, 2026:

There is also the case where you may not use LoRA, so requiring it to be set to True explicitly ensures you know you are using LoRA. In addition, SGLangJax will replace the base_layer and initialize the zero buffer when enable_static_lora is set; there are a few differences compared with disabling static LoRA.
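As a rough illustration of the zero-buffer point above (a hypothetical sketch, not the actual SGLangJax code): with static LoRA enabled, a target layer gains zero-initialized A/B adapter buffers, so until real adapter weights arrive via update_params the layer's output is identical to the base layer's.

```python
# Minimal sketch of a linear layer plus a zero-initialized static LoRA adapter.
# All names here are illustrative; real implementations use jax/flax arrays.
def matmul(x, w):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w)] for row in x]

def zeros(rows, cols):
    return [[0.0] * cols for _ in range(rows)]

def lora_forward(x, w_base, lora_a, lora_b, alpha=16.0, rank=8):
    """base output + (alpha / rank) * x @ A @ B"""
    base = matmul(x, w_base)
    delta = matmul(matmul(x, lora_a), lora_b)
    scale = alpha / rank
    return [[bv + scale * dv for bv, dv in zip(br, dr)] for br, dr in zip(base, delta)]

x = [[1.0, 2.0]]
w = [[1.0, 0.0], [0.0, 1.0]]
# Zero-initialized adapters (the "zero buffer") contribute nothing, so the
# output matches the plain base layer exactly:
assert lora_forward(x, w, zeros(2, 8), zeros(8, 2)) == matmul(x, w)
```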

if (
mappings is None
or not enable_static_lora
or lora_target_modules is None
Collaborator:

"not lora_target_modules" should have the same effect as "or lora_target_modules is None or len(lora_target_modules) == 0"

Collaborator Author:

Thanks.

self.engine = Engine(**self.args)

self.mappings = config.mapping_config.to_hf_mappings
self.to_hf_key_mappings = config.mapping_config.to_hf_mappings
Collaborator:

Redundant: these two lines assign the same value.

Collaborator Author:

Done.

args["enable_deterministic_sampling"] = True
if config.init_with_random_weights:
args["load_format"] = "dummy"
args["disable_radix_cache"] = config.disable_radix_cache
Collaborator:

Consider putting the checkers into a separate function.

Collaborator Author:

Done.

)


def update_hf_key_mappings_with_lora(
Collaborator:

We probably need to move this function to the top of the file; otherwise our internal tooling might complain about not being able to find it.

Collaborator Author:

Done.

@aolemila (Collaborator Author) commented:

I ran python scripts/grpo_demo_llama3_qwen2.py --num-batches 2 --num-test-batches 1 --root-dir=/home/gcpuser/aolemila --rollout-engine sglang_jax --enable-lora --lora-target-modules all.



Development

Successfully merging this pull request may close these issues.

[Feature] Support LoRA For SGLangJax Rollout

3 participants